{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install --quiet -U lancedb\n",
    "# duckdb is imported by the next cell and used by the SQL search tab.\n",
    "%pip install --quiet duckdb gradio transformers torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import PIL\n",
    "import duckdb\n",
    "import lancedb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First run setup: Download data and pre-process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<lance.dataset.LanceDataset at 0x3045db590>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import lance\n",
    "import pyarrow.compute as pc\n",
    "\n",
    "# download s3://eto-public/datasets/diffusiondb/small_10k.lance to this uri\n",
    "data = lance.dataset(\"~/datasets/rawdata.lance\").to_table()\n",
    "\n",
    "# Drop rows with a null prompt, create the table, then build the full-text-search index.\n",
    "db = lancedb.connect(\"~/datasets/demo\")\n",
    "tbl = db.create_table(\"diffusiondb\", data.filter(~pc.field(\"prompt\").is_null()))\n",
    "# create_fts_index() does not return the table, so `tbl` is deliberately not reassigned here.\n",
    "tbl.create_fts_index([\"prompt\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create / Open LanceDB Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect to the local LanceDB database directory and open the existing \"diffusiondb\" table.\n",
    "# `db` and `tbl` are module-level names; the search functions defined later read `tbl`.\n",
    "db = lancedb.connect(\"~/datasets/demo\")\n",
    "tbl = db.open_table(\"diffusiondb\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create CLIP embedding function for the text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
    "\n",
    "# Checkpoint name shared by the tokenizer, model and processor loaded below.\n",
    "MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
    "\n",
    "# `tokenizer` and `model` are the module-level objects used by embed_func.\n",
    "tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
    "model = CLIPModel.from_pretrained(MODEL_ID)\n",
    "# NOTE(review): `processor` is not referenced by any cell visible in this notebook.\n",
    "processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
    "\n",
    "def embed_func(query):\n",
    "    \"\"\"Embed a text query with the CLIP text encoder.\n",
    "\n",
    "    Args:\n",
    "        query: The search text to embed.\n",
    "\n",
    "    Returns:\n",
    "        A numpy array with the text features of `query` (first row of the batch).\n",
    "    \"\"\"\n",
    "    # truncation=True clips over-long queries to the tokenizer's maximum length;\n",
    "    # CLIP's text encoder has a fixed context size and fails on longer inputs.\n",
    "    inputs = tokenizer([query], padding=True, truncation=True, return_tensors=\"pt\")\n",
    "    text_features = model.get_text_features(**inputs)\n",
    "    # detach() drops the autograd graph so the tensor can be converted to numpy.\n",
    "    return text_features.detach().numpy()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Search functions for Gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_image_vectors(query):\n",
    "    \"\"\"Run a vector (embedding) search for `query`.\n",
    "\n",
    "    Args:\n",
    "        query: Free-text search string; it is embedded with `embed_func`.\n",
    "\n",
    "    Returns:\n",
    "        A `(images, code)` tuple: the top 9 matches converted by `_extract`,\n",
    "        and a Python snippet (string) that reproduces the search, for display.\n",
    "    \"\"\"\n",
    "    emb = embed_func(query)\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
    "        \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
    "        # !r quotes and escapes the query, so the snippet stays valid Python\n",
    "        # even when the query contains quotes or backslashes.\n",
    "        f\"embedding = embed_func({query!r})\\n\"\n",
    "        \"tbl.search(embedding).limit(9).to_df()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(emb).limit(9).to_df()), code)\n",
    "\n",
    "def find_image_keywords(query):\n",
    "    \"\"\"Run a keyword search for `query` by passing the raw string to `tbl.search`.\n",
    "\n",
    "    Args:\n",
    "        query: Keywords to search for.\n",
    "\n",
    "    Returns:\n",
    "        A `(images, code)` tuple: the top 9 matches converted by `_extract`,\n",
    "        and a Python snippet (string) that reproduces the search, for display.\n",
    "    \"\"\"\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
    "        \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
    "        # !r quotes and escapes the query, so the snippet stays valid Python\n",
    "        # even when the query contains quotes or backslashes.\n",
    "        f\"tbl.search({query!r}).limit(9).to_df()\"\n",
    "    )\n",
    "    return (_extract(tbl.search(query).limit(9).to_df()), code)\n",
    "\n",
    "def find_image_sql(query):\n",
    "    \"\"\"Run a DuckDB SQL `query` against the table, exposed as `diffusiondb`.\n",
    "\n",
    "    NOTE(security): `query` is executed verbatim, so anyone who can reach the\n",
    "    UI can run arbitrary SQL. Only expose this demo to trusted users.\n",
    "\n",
    "    Args:\n",
    "        query: SQL text; it should select the \"image\" and \"prompt\" columns\n",
    "            that `_extract` reads.\n",
    "\n",
    "    Returns:\n",
    "        A `(images, code)` tuple: the result rows converted by `_extract`,\n",
    "        and a Python snippet (string) that reproduces the query, for display.\n",
    "    \"\"\"\n",
    "    code = (\n",
    "        \"import lancedb\\n\"\n",
    "        \"import duckdb\\n\"\n",
    "        \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
    "        \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
    "        \"diffusiondb = tbl.to_lance()\\n\"\n",
    "        # !r quotes and escapes the SQL, so the snippet stays valid Python\n",
    "        # even when the SQL contains string literals in single quotes.\n",
    "        f\"duckdb.sql({query!r}).to_df()\"\n",
    "    )\n",
    "    # Not dead code: DuckDB resolves the `diffusiondb` name used in the SQL text\n",
    "    # by looking up this Python variable (replacement scan).\n",
    "    diffusiondb = tbl.to_lance()\n",
    "    return (_extract(duckdb.sql(query).to_df()), code)\n",
    "\n",
    "def _extract(df):\n",
    "    \"\"\"Convert a search-result DataFrame into gallery items.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with an \"image\" column (encoded image bytes) and a\n",
    "            \"prompt\" column.\n",
    "\n",
    "    Returns:\n",
    "        A list of `(PIL image, prompt)` tuples, one per row of `df`.\n",
    "    \"\"\"\n",
    "    # `import PIL` alone does not load the PIL.Image submodule; import it\n",
    "    # explicitly instead of relying on another library having loaded it.\n",
    "    from PIL import Image\n",
    "\n",
    "    image_col = \"image\"\n",
    "    return [\n",
    "        (Image.open(io.BytesIO(image_bytes)), prompt)\n",
    "        for image_bytes, prompt in zip(df[image_col], df[\"prompt\"])\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Gradio interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7881\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7881/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "\n",
    "\n",
    "# Three tabs (one per search mode) feed a shared code viewer and image gallery.\n",
    "with gr.Blocks() as demo:\n",
    "    with gr.Row():\n",
    "        with gr.Tab(\"Embeddings\"):\n",
    "            vector_query = gr.Textbox(value=\"portraits of a person\", show_label=False)\n",
    "            b1 = gr.Button(\"Submit\")\n",
    "        with gr.Tab(\"Keywords\"):\n",
    "            keyword_query = gr.Textbox(value=\"ninja turtle\", show_label=False)\n",
    "            b2 = gr.Button(\"Submit\")\n",
    "        with gr.Tab(\"SQL\"):\n",
    "            sql_query = gr.Textbox(value=\"SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9\", show_label=False)\n",
    "            b3 = gr.Button(\"Submit\")\n",
    "    with gr.Row():\n",
    "        code = gr.Code(label=\"Code\", language=\"python\")\n",
    "    with gr.Row():\n",
    "        # Layout options are passed to the constructor: the chained .style(...)\n",
    "        # helper is not available in current Gradio releases.\n",
    "        gallery = gr.Gallery(\n",
    "            label=\"Found images\",\n",
    "            show_label=False,\n",
    "            elem_id=\"gallery\",\n",
    "            columns=3,\n",
    "            rows=3,\n",
    "            object_fit=\"contain\",\n",
    "            height=\"auto\",\n",
    "        )\n",
    "\n",
    "    # Each handler returns (gallery items, code snippet), matching `outputs`.\n",
    "    b1.click(find_image_vectors, inputs=vector_query, outputs=[gallery, code])\n",
    "    b2.click(find_image_keywords, inputs=keyword_query, outputs=[gallery, code])\n",
    "    b3.click(find_image_sql, inputs=sql_query, outputs=[gallery, code])\n",
    "\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
